package web.server.ioc;

import io.netty.util.internal.StringUtil;
import lombok.SneakyThrows;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import web.server.ServerConfig;
import web.server.annotation.Func;
import web.server.annotation.RequestMethod;
import web.server.utils.JsonHelper;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.JarURLConnection;
import java.net.URL;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * Registry of HTTP handler methods discovered by classpath scanning.
 *
 * <p>{@link #register(ServerConfig)} scans the configured packages for public methods annotated
 * with {@link Func} and records each one under its URL in a GET table or a POST table;
 * {@link #invokeApi(String, String, Object[])} looks a handler up by HTTP method and URI and
 * invokes it. Both tables are {@link ConcurrentHashMap}s.
 */
public class ApiStore {

    private static final Logger log = LoggerFactory.getLogger(ApiStore.class);

    /** Handlers registered for any request method other than GET, keyed by URL path. */
    private static final Map<String, Api> POST_APIS = new ConcurrentHashMap<>();
    /** Handlers registered for GET, keyed by URL path. */
    private static final Map<String, Api> GET_APIS = new ConcurrentHashMap<>();

    /**
     * Looks up the handler registered for {@code method} and {@code uri} and invokes it.
     *
     * <p>Handlers without parameters are invoked with no arguments. Otherwise {@code args[0]}
     * is converted to the type of the handler's first parameter via {@link JsonHelper} and
     * {@link #castParam(Object, Class)}, and passed as the single argument. Handlers declaring
     * more than one parameter still receive only that one argument.
     *
     * @param method HTTP method name; only {@code "GET"} and {@code "POST"} are routed
     * @param uri    request path used as the lookup key; {@code null} matches nothing
     * @param args   request arguments; {@code args[0]} is cast to {@code String}. When
     *               {@code args} is {@code null} or empty, the handler receives {@code null}
     * @return the handler's result, or {@code ""} when no handler matches
     * @throws Exception whatever the JSON conversion or the handler invocation throws
     */
    public static Object invokeApi(String method, String uri, Object[] args) throws Exception {
        Map<String, Api> table;
        if ("GET".equals(method)) {
            table = GET_APIS;
        } else if ("POST".equals(method)) {
            table = POST_APIS;
        } else {
            return "";
        }
        // ConcurrentHashMap rejects null keys, so guard before the lookup.
        Api api = uri == null ? null : table.get(uri);
        if (api == null) {
            return "";
        }
        Class<?>[] parameterTypes = api.method.getParameterTypes();
        if (parameterTypes.length == 0) {
            return api.invoke(new Object[0]);
        }
        Class<?> parameterType = parameterTypes[0];
        Object raw = (args == null || args.length == 0)
                ? null
                : JsonHelper.jsonObjectToJavaObject((String) args[0], parameterType);
        return api.invoke(new Object[]{castParam(raw, parameterType)});
    }

    /**
     * Coerces {@code object} to {@code clazz}.
     *
     * <p>Returns {@code object} unchanged when it is already an instance of {@code clazz} (or of
     * the wrapper type when {@code clazz} is primitive). Otherwise it builds a new value through
     * the target type's public {@code (String)} constructor applied to {@code object.toString()}.
     *
     * @param object value to convert; may be {@code null}
     * @param clazz  target type
     * @return the converted value, or {@code null} when {@code object} is {@code null}
     * @throws IllegalArgumentException if the value cannot be converted
     */
    @SuppressWarnings("unchecked")
    public static <T> T castParam(Object object, Class<T> clazz) {
        if (object == null) {
            return null;
        }
        // Primitive Class objects match no instance and have no constructors; use the wrapper.
        Class<?> target = wrap(clazz);
        if (target.isInstance(object)) {
            return (T) object;
        }
        try {
            Constructor<?> constructor = target.getConstructor(String.class);
            return (T) constructor.newInstance(object.toString());
        } catch (Exception e) {
            throw new IllegalArgumentException("Cannot cast object to " + clazz.getName(), e);
        }
    }

    /** Returns the wrapper class for a primitive type, or {@code type} itself otherwise. */
    private static Class<?> wrap(Class<?> type) {
        if (!type.isPrimitive()) {
            return type;
        }
        if (type == int.class) {
            return Integer.class;
        }
        if (type == long.class) {
            return Long.class;
        }
        if (type == boolean.class) {
            return Boolean.class;
        }
        if (type == double.class) {
            return Double.class;
        }
        if (type == float.class) {
            return Float.class;
        }
        if (type == short.class) {
            return Short.class;
        }
        if (type == byte.class) {
            return Byte.class;
        }
        if (type == char.class) {
            return Character.class;
        }
        return Void.class;
    }

    /**
     * Scans every package listed in {@code serverConfig.getApis()} and registers each
     * {@link Func}-annotated public method as an HTTP handler.
     *
     * <p>A handler's URL is the class-level {@code Func.path()} (if non-empty) joined with the
     * method-level path, which defaults to the method name. Its request methods default to GET.
     * One instance of each handler class is created through its no-arg constructor and shared
     * by all of that class's handlers. Checked exceptions from scanning or instantiation
     * propagate undeclared via Lombok's {@code @SneakyThrows}.
     *
     * @param serverConfig supplies the package names to scan; a {@code null} list scans nothing
     */
    @SneakyThrows
    public static void register(ServerConfig serverConfig) {
        List<String> packagePaths = serverConfig.getApis();
        log.info("scan package {}", packagePaths);
        if (packagePaths == null) {
            return;
        }
        for (String packagePath : packagePaths) {
            // Every class found under the package (sub-packages included).
            for (Class<?> aClass : findClasses(packagePath)) {
                registerClass(aClass);
            }
        }
    }

    /** Registers every {@link Func}-annotated public method of {@code aClass}. */
    private static void registerClass(Class<?> aClass) throws ReflectiveOperationException {
        Func classFunc = aClass.getAnnotation(Func.class);
        String baseUrl = "";
        if (classFunc != null && !StringUtil.isNullOrEmpty(classFunc.path())) {
            baseUrl = classFunc.path();
        }
        // Created on first use, so classes without handlers need no no-arg constructor.
        Object handler = null;
        for (Method method : aClass.getMethods()) {
            Func methodFunc = method.getAnnotation(Func.class);
            if (methodFunc == null) {
                continue;
            }
            String funcUrl = StringUtil.isNullOrEmpty(methodFunc.path())
                    ? method.getName()
                    : methodFunc.path();
            RequestMethod[] requestMethods = methodFunc.method();
            if (requestMethods.length == 0) {
                requestMethods = new RequestMethod[]{RequestMethod.GET};
            }
            String url = getUrl(baseUrl, funcUrl);
            log.info("扫描到API函数 class [{}] method [{}] httpUrl [{}] requestMethods {}",
                    aClass, method.getName(), url, Arrays.toString(requestMethods));
            if (handler == null) {
                handler = aClass.getDeclaredConstructor().newInstance();
            }
            // Allows invocation even when the declaring class itself is not public.
            method.setAccessible(true);
            for (RequestMethod requestMethod : requestMethods) {
                // Every request method other than GET is routed to the POST table.
                Map<String, Api> table = requestMethod == RequestMethod.GET ? GET_APIS : POST_APIS;
                Api previous = table.put(url, new Api(method.getName(), method, handler));
                if (previous != null && !previous.method.equals(method)) {
                    log.warn("API url [{}] for {} re-registered by {}.{}, replacing {}",
                            url, requestMethod, aClass.getName(), method.getName(), previous.method);
                }
            }
        }
    }

    /**
     * Joins a base path and a function path into a URL path with exactly one leading slash.
     *
     * <p>A trailing slash on {@code baseUrl} is dropped, and {@code funcUrl} gets a leading
     * slash if it lacks one. For example {@code ("api/", "list")} yields {@code "/api/list"} and
     * {@code ("", "list")} yields {@code "/list"}. {@code null} arguments are treated as empty.
     */
    public static String getUrl(String baseUrl, String funcUrl) {
        String base = baseUrl == null ? "" : baseUrl;
        String func = funcUrl == null ? "" : funcUrl;
        if (!base.startsWith("/")) {
            base = "/" + base;
        }
        if (base.endsWith("/")) {
            base = base.substring(0, base.length() - 1);
        }
        if (!func.startsWith("/")) {
            func = "/" + func;
        }
        return base + func;
    }

    /**
     * Loads every class in {@code packageName} and its sub-packages.
     *
     * <p>The package is looked up through the context class loader (falling back to this
     * class's loader). Both exploded directories ({@code file:} URLs) and jar files
     * ({@code jar:} URLs) are scanned. Other URL protocols are skipped with a warning.
     * Classes are loaded (and initialized) with the same loader that located them.
     *
     * @param packageName dotted package name, e.g. {@code "web.server.api"}
     * @return the loaded classes; empty if the package is not found
     * @throws Throwable if a resource cannot be read or a discovered class cannot be loaded
     */
    public static List<Class<?>> findClasses(String packageName) throws Throwable {
        ClassLoader classLoader = resolveClassLoader();
        String path = packageName.replace('.', '/');
        Enumeration<URL> resources = classLoader.getResources(path);

        List<Class<?>> classes = new ArrayList<>();
        while (resources.hasMoreElements()) {
            URL resource = resources.nextElement();
            String protocol = resource.getProtocol();
            if ("file".equals(protocol)) {
                classes.addAll(findClasses(new File(resource.toURI()), packageName, classLoader));
            } else if ("jar".equals(protocol)) {
                classes.addAll(findClassesInJar(resource, path, classLoader));
            } else {
                log.warn("skip unsupported resource {} while scanning package {}", resource, packageName);
            }
        }
        return classes;
    }

    /** Returns the context class loader, or this class's loader when none is set. */
    private static ClassLoader resolveClassLoader() {
        ClassLoader contextLoader = Thread.currentThread().getContextClassLoader();
        return contextLoader != null ? contextLoader : ApiStore.class.getClassLoader();
    }

    /** Recursively loads the classes stored as {@code .class} files under {@code directory}. */
    private static List<Class<?>> findClasses(File directory, String packageName, ClassLoader classLoader)
            throws ClassNotFoundException {
        List<Class<?>> classes = new ArrayList<>();
        // listFiles() returns null when the path is missing, not a directory, or unreadable.
        File[] files = directory.listFiles();
        if (files == null) {
            return classes;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                classes.addAll(findClasses(file, qualify(packageName, fileName), classLoader));
            } else if (isClassFile(fileName)) {
                String className = qualify(packageName, stripClassSuffix(fileName));
                // Let ClassNotFoundException propagate as-is so the real cause is visible.
                classes.add(Class.forName(className, true, classLoader));
            }
        }
        return classes;
    }

    /** Loads every class whose jar entry lies under {@code path} in the jar behind {@code resource}. */
    private static List<Class<?>> findClassesInJar(URL resource, String path, ClassLoader classLoader)
            throws IOException, ClassNotFoundException {
        List<Class<?>> classes = new ArrayList<>();
        String prefix = path.isEmpty() ? "" : path + "/";
        JarURLConnection connection = (JarURLConnection) resource.openConnection();
        // Open a private JarFile so that closing it cannot invalidate the JVM's shared cached copy.
        connection.setUseCaches(false);
        try (JarFile jarFile = connection.getJarFile()) {
            Enumeration<JarEntry> entries = jarFile.entries();
            while (entries.hasMoreElements()) {
                String entryName = entries.nextElement().getName();
                if (!entryName.startsWith(prefix) || entryName.startsWith("META-INF/")) {
                    continue;
                }
                String simpleName = entryName.substring(entryName.lastIndexOf('/') + 1);
                if (isClassFile(simpleName)) {
                    String className = stripClassSuffix(entryName).replace('/', '.');
                    classes.add(Class.forName(className, true, classLoader));
                }
            }
        }
        return classes;
    }

    /**
     * Returns true for {@code .class} file names that can denote a loadable class. Names
     * containing '-' ({@code module-info.class}, {@code package-info.class}) are skipped
     * because they are not valid class names.
     */
    private static boolean isClassFile(String simpleName) {
        return simpleName.endsWith(".class") && simpleName.indexOf('-') < 0;
    }

    /** Removes the trailing {@code ".class"} from a file or entry name. */
    private static String stripClassSuffix(String name) {
        return name.substring(0, name.length() - ".class".length());
    }

    /** Joins a package name and a simple name, tolerating the unnamed (empty) package. */
    private static String qualify(String packageName, String simpleName) {
        return packageName.isEmpty() ? simpleName : packageName + '.' + simpleName;
    }
}
